@Oleg-Goncharov Oleg-Goncharov commented Jan 28, 2026

Description

This PR fuses pre-swizzling into the grouped MXFP8 quantization kernel so that scaling factors are stored in the format expected by GEMM.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added a template parameter, WITH_GEMM_SWIZZLED_SCALES, to the kernel to control the scaling-factor format.
  • Added a new member, with_gemm_swizzled_scales, to GroupedTensor.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

greptile-apps bot commented Jan 28, 2026

Greptile Overview

Greptile Summary

This PR adds optional pre-swizzling support to the grouped MXFP8 quantization kernel, allowing scaling factors to be stored in the format expected by GEMM operations. The implementation adds a new template parameter WITH_GEMM_SWIZZLED_SCALES to the kernel and uses compile-time branching to select between standard linear indexing and GEMM-swizzled indexing for both rowwise and colwise scaling paths. A new with_gemm_swizzled_scales boolean field was added to the GroupedTensor struct to control this behavior at runtime.

Key Changes:

  • Added WITH_GEMM_SWIZZLED_SCALES template parameter to the group_quantize_mxfp8_kernel
  • Conditional scale index computation using gemm_swizzled_scale_idx() function for both colwise (line 484) and rowwise (line 617) scaling paths
  • Added with_gemm_swizzled_scales field to GroupedTensor struct with proper initialization
  • Wrapped kernel instantiation with TRANSFORMER_ENGINE_SWITCH_CONDITION macro to generate both swizzled and non-swizzled versions
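
The compile-time selection described in these key changes can be sketched in plain C++17. This is a minimal illustration, not the kernel's code: `toy_swizzled_idx`, `scale_index`, and `dispatch_scale_index` are hypothetical names, and the permutation inside `toy_swizzled_idx` is a placeholder rather than the layout that `gemm_swizzled_scale_idx()` actually produces.

```cpp
#include <cstddef>

// Hypothetical placeholder for a swizzled index function. It is NOT the
// layout computed by gemm_swizzled_scale_idx(); it only gives the swizzled
// branch something different from the linear formula.
inline std::size_t toy_swizzled_idx(std::size_t x, std::size_t y, std::size_t num_tiles) {
  return x * num_tiles + y;
}

// The bool template parameter picks the indexing scheme at compile time,
// so each instantiation contains only one of the two branches.
template <bool kWithSwizzledScales>
std::size_t scale_index(std::size_t x, std::size_t y, std::size_t stride,
                        std::size_t num_tiles) {
  if constexpr (kWithSwizzledScales) {
    return toy_swizzled_idx(x, y, num_tiles);
  } else {
    return y * stride + x;  // standard linear (row-major) indexing
  }
}

// Maps a runtime flag to one of the two instantiations, which is the role
// a switch-style macro would play around a kernel launch.
inline std::size_t dispatch_scale_index(bool with_swizzled_scales, std::size_t x,
                                        std::size_t y, std::size_t stride,
                                        std::size_t num_tiles) {
  return with_swizzled_scales ? scale_index<true>(x, y, stride, num_tiles)
                              : scale_index<false>(x, y, stride, num_tiles);
}
```

Because the branch is resolved per instantiation, the runtime flag is only consulted once, at dispatch, and not inside the indexing code itself.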

Confidence Score: 5/5

  • This PR is safe to merge
  • The implementation follows the exact same pattern as existing swizzling code in quantize_mxfp8.cuh and gated_mxfp8.cuh. The template parameter approach ensures zero runtime overhead. The GroupedTensor field is properly initialized in both constructor and clear() method. The code is well-structured and consistent with the codebase
  • No files require special attention

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh | Added WITH_GEMM_SWIZZLED_SCALES template parameter and conditional scale index computation using gemm_swizzled_scale_idx for both rowwise and colwise scaling paths |
| transformer_engine/common/common.h | Added with_gemm_swizzled_scales boolean field to GroupedTensor struct with proper initialization in constructor and clear() method |

@greptile-apps greptile-apps bot left a comment

4 files reviewed, no comments

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@Oleg-Goncharov Oleg-Goncharov force-pushed the pr_mxfp8_grouped_preswizzle branch from bf07d9d to ed61ff7 Compare February 11, 2026 19:27
@Oleg-Goncharov Oleg-Goncharov added enhancement New feature or request MoE labels Feb 11, 2026
@Oleg-Goncharov

/te-ci

@greptile-apps greptile-apps bot left a comment

2 files reviewed, no comments

@ksivaman ksivaman left a comment

LGTM

@greptile-apps greptile-apps bot left a comment

2 files reviewed, 1 comment

Comment on lines +482 to 489
size_t scale_idx = 0;
if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
  scale_idx = gemm_swizzled_scale_idx(global_scales_offset_X, global_scales_offset_Y,
                                      DIVUP(rows, static_cast<size_t>(128)));
} else {
  scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
}
scales_colwise[scale_idx] = biased_exponent;

Type inconsistency: scale_idx is declared as size_t in the colwise path but as int in the rowwise path (line 615). Should use consistent type (size_t) in both paths.

Suggested change
size_t scale_idx = 0;
if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
  scale_idx = gemm_swizzled_scale_idx(global_scales_offset_X, global_scales_offset_Y,
                                      DIVUP(rows, static_cast<size_t>(128)));
} else {
  scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
}
scales_colwise[scale_idx] = biased_exponent;

size_t scale_idx = 0;
if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
  scale_idx = gemm_swizzled_scale_idx(global_scales_offset_X, global_scales_offset_Y,
                                      DIVUP(rows, static_cast<size_t>(128)));
} else {
  scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
}
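
The type concern raised in the review comment above can be illustrated with a small standalone check. This is a sketch that assumes a linear index of the form `y * stride + x`, as in the non-swizzled branch; `linear_scale_idx` and `fits_in_int` are hypothetical helpers, not part of the PR, and the example sizes are made up to show an index that a 32-bit `int` could not hold.

```cpp
#include <cstddef>
#include <limits>

// Hypothetical helper: linear scale index computed entirely in size_t.
constexpr std::size_t linear_scale_idx(std::size_t y, std::size_t stride, std::size_t x) {
  return y * stride + x;
}

// True when the index is representable as an int without narrowing.
constexpr bool fits_in_int(std::size_t idx) {
  return idx <= static_cast<std::size_t>(std::numeric_limits<int>::max());
}
```

With a 32-bit `int`, `fits_in_int(linear_scale_idx(50000, 50000, 0))` is false, so storing such an index in an `int` variable would narrow it. Declaring the index as `size_t` in both paths avoids that conversion.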

@ksivaman ksivaman merged commit 93d51c8 into NVIDIA:main Feb 12, 2026
11 of 12 checks passed